from functools import partial
from typing import Optional, Sequence, Type

import torch
import torch.nn as nn
from torch.distributed.fsdp import (
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy, size_based_auto_wrap_policy
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_in_ddp(
    module: nn.Module,
    local_rank: int,
) -> DDP:
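    """Move `module` onto its local GPU and wrap it in DistributedDataParallel."""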
    return DDP(module=module.to(f"cuda:{local_rank}"), device_ids=[local_rank])


def wrap_in_fsdp(
    module: nn.Module,
    local_rank: Optional[int],
    param_precision: torch.dtype,
    cast_inputs: bool,
    mixed_precision_ignored_classes: Sequence[Type[nn.Module]],
    offload_params: bool,
    print_model: bool,
    min_num_params: Optional[int],
    modules_to_wrap: Optional[tuple[Type[nn.Module], ...]],
    is_logging_process: bool,
) -> FSDP:
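    """Wrap `module` in FullyShardedDataParallel (FULL_SHARD) with mixed precision.

    The auto-wrap policy is derived from exactly one of `modules_to_wrap`
    (wrap the listed module classes) or `min_num_params` (size-based wrapping).
    """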
    assert (modules_to_wrap is None) != (
        min_num_params is None
    ), "The FSDP arguments `modules_to_wrap` and `min_num_params` are mutually exclusive: supply exactly one of them."

    if modules_to_wrap is not None:
        # Wrap each listed module class in its own FSDP unit.
        wrap_policy = ModuleWrapPolicy(modules_to_wrap)
    else:
        # Otherwise wrap submodules once their parameter count exceeds the threshold
        # (the assert above guarantees `min_num_params` is set in this branch).
        wrap_policy = partial(
            size_based_auto_wrap_policy, min_num_params=min_num_params
        )

    wrapped = FSDP(
        module,
        # FULL_SHARD shards parameters, gradients, and optimizer state across ranks.
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        device_id=local_rank,
        mixed_precision=MixedPrecision(
            param_dtype=param_precision,
            reduce_dtype=param_precision,
            cast_forward_inputs=cast_inputs,
            # Private field: the listed module classes are kept out of mixed precision.
            _module_classes_to_ignore=mixed_precision_ignored_classes,
        ),
        # When enabled, sharded parameters live on CPU and are moved to GPU for compute.
        cpu_offload=CPUOffload(offload_params=offload_params),
        auto_wrap_policy=wrap_policy,
    )

    if print_model and is_logging_process:
        print("------- MODEL AFTER WRAPPING IN FSDP -------")
        print(wrapped)
        print("--------------------------------------------")

    return wrapped
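

# Example usage (a minimal sketch; assumes the default process group is already
# initialized and `TransformerBlock` is a hypothetical module class to wrap):
#
#     model = wrap_in_fsdp(
#         module=model,
#         local_rank=local_rank,
#         param_precision=torch.bfloat16,
#         cast_inputs=True,
#         mixed_precision_ignored_classes=[nn.LayerNorm],
#         offload_params=False,
#         print_model=True,
#         min_num_params=None,
#         modules_to_wrap=(TransformerBlock,),
#         is_logging_process=(local_rank == 0),
#     )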
